package com.snail.common.lock;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.snail.common.core.text.Convert;
import com.snail.common.core.utils.StringUtils;
import com.snail.common.lock.config.LockConfig;
import com.snail.common.redis.service.RedisService;
import lombok.SneakyThrows;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Distributed lock backed by Redis.
 * <p>
 * A successful {@link #getLock(String)} stores a random UUID token as the lock value (via
 * {@code RedisService#getLock}) and remembers that token in a per-thread map held in a
 * {@link TransmittableThreadLocal}. {@link #unLock(String)} deletes the Redis key only when its current value still
 * equals the remembered token, so a caller cannot release a lock owned by someone else.
 * <p>
 * NOTE(review): a fresh token is generated on every {@code getLock} call, so the lock does not appear to be
 * reentrant for the same thread and key — confirm against {@code RedisService#getLock}.
 *
 * @Description: Distributed lock
 * @Author: Snail
 * @CreateDate: 2023/8/22 17:17
 * @Version: V1.0
 */
@Component
public class Lock {

    /** Delay between two acquisition attempts while waiting for the lock. */
    private static final long RETRY_INTERVAL_MILLIS = 100L;

    /**
     * Compare-and-delete: removes KEYS[1] only if its value equals ARGV[1] (the caller's token), so a holder can
     * never delete a lock that now belongs to someone else. Built once; the script is immutable and reusable.
     */
    private static final RedisScript<Boolean> UNLOCK_SCRIPT = RedisScript.of(
            "if redis.call(\"get\",KEYS[1]) == ARGV[1] then " +
            "    return redis.call(\"del\",KEYS[1]) " +
            " else " +
            "    return 0 " +
            "end", Boolean.class);

    /** Per-thread map of lock key -> token this thread acquired it with. */
    private static final TransmittableThreadLocal<Map<String, String>> THREAD_LOCAL = new TransmittableThreadLocal<>();
    @Autowired
    private RedisService redisService;
    @Autowired
    private LockConfig lockConfig;

    /**
     * Acquires the distributed lock, retrying every {@value #RETRY_INTERVAL_MILLIS} ms until it succeeds or the
     * configured wait time ({@code lockConfig.getWaitTimeOut()}, compared against elapsed milliseconds) is used up.
     * At least one attempt is always made. The lock's own expiry is {@code lockConfig.getLockTimeOut()} seconds.
     *
     * @param lock lock key
     * @return {@code true} once the lock is held (failure is reported by exception, never by {@code false})
     * @throws Exception (sneakily) when the lock could not be acquired within the wait time
     * @throws InterruptedException (sneakily) when interrupted while waiting; the interrupt flag is restored
     */
    @SneakyThrows
    public boolean getLock(String lock) {
        String uuid = UUID.randomUUID().toString();
        long startMillis = System.currentTimeMillis();
        // Try first, then check the deadline, so a wait time of 0 still means "try once" rather than "never try".
        boolean acquired = tryAcquire(lock, uuid);
        while (!acquired && System.currentTimeMillis() - startMillis < lockConfig.getWaitTimeOut()) {
            sleepBeforeRetry();
            acquired = tryAcquire(lock, uuid);
        }
        if (!acquired) {
            throw new Exception(StringUtils.format("获取分布式锁超时,等待时间{},超时时间{}", (System.currentTimeMillis() - startMillis), lockConfig.getWaitTimeOut()));
        }
        // Remember the token so unLock can prove ownership.
        set(lock, uuid);
        return true;
    }

    /**
     * Makes a single acquisition attempt through {@code RedisService#getLock}. A {@code null} reply is treated as
     * failure instead of being unboxed (which would throw a NullPointerException).
     */
    private boolean tryAcquire(String lock, String token) {
        return Boolean.TRUE.equals(redisService.getLock(lock, token, lockConfig.getLockTimeOut(), TimeUnit.SECONDS));
    }

    /** Sleeps between attempts, restoring the interrupt flag before propagating an interruption. */
    private static void sleepBeforeRetry() throws InterruptedException {
        try {
            Thread.sleep(RETRY_INTERVAL_MILLIS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw e;
        }
    }

    /**
     * Releases the lock if, and only if, it is still held with the token this thread acquired it with.
     *
     * @param lock lock key
     * @return {@code true} if the Redis key was deleted by this call; {@code false} if this thread holds no token
     *         for the key, or the key no longer carries our token (for example, it expired or was re-acquired
     *         by someone else)
     */
    public boolean unLock(String lock) {
        String token = get(lock);
        if (token == null || token.isEmpty()) {
            // Nothing of ours to release. Skipping the script also avoids deleting a key whose value happens to be "".
            return false;
        }
        List<String> keys = Collections.singletonList(lock);
        Boolean released = redisService.executeScript(UNLOCK_SCRIPT, keys, token);
        // Whether we deleted the key or it was no longer ours, this thread does not own the lock any more.
        // Drop the token so stale entries do not accumulate in long-lived pooled threads.
        remove(lock);
        return Boolean.TRUE.equals(released);
    }

    /**
     * Stores the token for {@code key} in the current thread's map.
     *
     * @param key   lock key
     * @param value token; {@code null} is stored as an empty string
     */
    public static void set(String key, String value) {
        Map<String, String> map = getLocalMap();
        map.put(key, value == null ? StringUtils.EMPTY : value);
    }

    /**
     * Returns the token stored for {@code key} in the current thread's map.
     *
     * @param key lock key
     * @return the stored token, or an empty string when none is stored
     */
    public static String get(String key) {
        Map<String, String> map = getLocalMap();
        return Convert.toStr(map.getOrDefault(key, StringUtils.EMPTY));
    }

    /**
     * Removes the token stored for {@code key} from the current thread's map.
     *
     * @param key lock key
     * @return the removed token, or {@code null} if none was stored
     */
    public static String remove(String key) {
        Map<String, String> map = getLocalMap();
        return map.remove(key);
    }

    /**
     * Returns the current thread's token map, lazily creating it on first use.
     * NOTE(review): a ConcurrentHashMap is presumably used because TransmittableThreadLocal may hand the same map
     * reference to tasks run on other threads — confirm.
     *
     * @return the (never {@code null}) token map of the current thread
     */
    public static Map<String, String> getLocalMap() {
        Map<String, String> map = THREAD_LOCAL.get();
        if (map == null) {
            map = new ConcurrentHashMap<>(64);
            THREAD_LOCAL.set(map);
        }
        return map;
    }
}
